{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6a627acd-442f-4cd5-a319-6258cc34edd4",
   "metadata": {},
   "source": [
    "# Physics Simulation of 3D Gaussian Splats Using Simplicits - Now With Collisions!\n",
    "\n",
    "Let's simulate 3D Gaussian Splat objects using [Simplicits](https://research.nvidia.com/labs/toronto-ai/simplicits/), fully integrated into the Kaolin Library. We will be able to set up and interactively view the simulation directly in this Jupyter notebook.\n",
    "\n",
    "With v0.18.0, Kaolin also supports collision handling between objects, which we will also show here.\n",
    "\n",
    "<img src=\"../../../assets/physics_bulldozer.gif\" alt=\"image info\" width=\"500px\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f90cd2f2-584e-4d9f-bcca-220d6959b66f",
   "metadata": {},
   "source": [
    "## Installation and Requirements\n",
    "For splat rendering, we will be relying on a specific version of [INRIA's splatting and rasterization code](https://github.com/graphdeco-inria/gaussian-splatting). In the setup below, make sure the paths and packages are set correctly to allow importing inria code into the notebook.\n",
    "\n",
    "We have recently tested this notebook with the following environment. Please follow [Kaolin Installation docs](https://kaolin.readthedocs.io/en/latest/notes/installation.html) to install Kaolin. \n",
    "- python 3.11.10\n",
    "- cuda 12.4\n",
    "- pytorch 2.5.1\n",
    "- setuptools 70.1.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18fa3db8-ceae-41bf-947a-fb9060b17af0",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Install necessary packages\n",
    "!pip install -q plyfile k3d matplotlib"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3aa22b4-5df4-49bc-aae7-f0190341f8b3",
   "metadata": {},
   "source": [
    "### Import Kaolin Library and Other Requirements"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73c496c5-1768-4990-bc1a-b9cdad751066",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import ipywidgets\n",
    "import json\n",
    "import kaolin\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import logging\n",
    "import sys\n",
    "import time\n",
    "import threading  \n",
    "import k3d\n",
    "from pathlib import Path\n",
    "from functools import partial\n",
    "import warp as wp\n",
    "\n",
    "import torch\n",
    "import torchvision\n",
    "\n",
    "from IPython.display import display\n",
    "from ipywidgets import Button, HBox, VBox\n",
    "\n",
    "logging.basicConfig(level=logging.INFO, stream=sys.stdout, format=\"%(asctime)s|%(levelname)8s| %(message)s\")\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "def log_tensor(t, name, **kwargs):\n",
    "    \"\"\" Debugging util, e.g. call: log_tensor(t, 'my tensor', print_stats=True) \"\"\"\n",
    "    logger.info(kaolin.utils.testing.tensor_info(t, name=name, **kwargs)) "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13b35d07-324e-4030-8ec8-b0d5e8342daa",
   "metadata": {},
   "source": [
    "### Import local Gaussian utils for this notebook\n",
    "\n",
    "In order to deform the Gaussians during simulation, we define a couple functions in a utility file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "056dda33-6ea8-434f-a7e5-1ddacfd61f86",
   "metadata": {},
   "outputs": [],
   "source": [
    "from gaussian_utils import transform_gaussians_lbs, pad_transforms, PHYS_NOTEBOOKS_DIR"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ed8318c-2fbd-472a-af77-dec18bd43eac",
   "metadata": {},
   "source": [
    "### Setting up Inria Gaussian Splatting Codebase\n",
    "\n",
    "This will clone and build the Gaussian renderer in a subfolder relative to this notebook: `examples/tutorial/physics/inria/`. If the build fails, you may need to set `REBUILD_INRIA=True`, fix issues and rerun this cell.\n",
    "\n",
    "**Note:** We have occasionally run into the following [bug](https://github.com/graphdeco-inria/gaussian-splatting/issues/373), which requires adding `import <float.h>` to the imports in `examples/tutorial/physics/inria/gaussian-splatting/submodules/simple-knn/simple_knn.cu`. \n",
    "\n",
    "INRIA's Gaussian Splatting is not a package. Once it's built, this block will `cd` into `..../kaolin/examples/tutorial/physics/inria/gaussian-splatting directory` in order to import gaussian rendering utilities. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f3ed4d1-4ac6-4b3c-a815-5cbdeaee0837",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "#### Setup and Installation ###\n",
    "\n",
    "REBUILD_INRIA = False\n",
    "inria_path = os.path.join(PHYS_NOTEBOOKS_DIR, 'inria', 'gaussian-splatting')\n",
    "if REBUILD_INRIA or not os.path.isdir(inria_path):\n",
    "    logger.info(f'Cloning and building inria gaussian-splatting in {inria_path}')\n",
    "    %cd {PHYS_NOTEBOOKS_DIR}\n",
    "\n",
    "    ### Create an inria folder\n",
    "    %mkdir inria\n",
    "    %cd inria\n",
    "\n",
    "    ### Clone the repo recursively\n",
    "    !git clone --recursive https://github.com/graphdeco-inria/gaussian-splatting.git    \n",
    "\n",
    "    ### Install the submodules\n",
    "    %cd gaussian-splatting\n",
    "    !git checkout --recurse-submodules 472689c\n",
    "    !pip install submodules/diff-gaussian-rasterization\n",
    "    !pip install submodules/simple-knn\n",
    "else:\n",
    "    logger.info(f'Inria gaussian-splatting already exists; cd {inria_path}')\n",
    "    %cd {inria_path}\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b0110e9-3e42-4b34-86de-25e5bfe34bf3",
   "metadata": {},
   "source": [
    "### Import Inria Gaussian Splat rendering utils\n",
    "\n",
    "**If you get a `module not found` error, check your paths**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23f40655-f21d-4619-a23d-e3fa2c86d1ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gaussian splatting dependencies\n",
    "from utils.graphics_utils import focal2fov\n",
    "from utils.system_utils import searchForMaxIteration\n",
    "from gaussian_renderer import render, GaussianModel\n",
    "from scene.cameras import Camera as GSCamera\n",
    "from utils.general_utils import strip_symmetric, build_scaling_rotation\n",
    "%pwd"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f652d938-bcfb-498d-9407-8563f1645227",
   "metadata": {},
   "source": [
    "## Download Splat Models from AWS\n",
    "Lets grab two pre-trained 3D Gaussian Splat models from AWS.\n",
    "We can unzip and set the splat model path below to the correct `.ply` file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5154170-a133-4df5-816a-1fdeb87788ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download and unzip the nerfsynthetic bulldozer\n",
    "!if test -d output/dozer; then echo \"Pretrained bulldozer splats already exist.\"; else wget https://nvidia-kaolin.s3.us-east-2.amazonaws.com/data/dozer.zip -P output/; unzip output/dozer.zip -d output/; fi;\n",
    "model_path = 'output/dozer/point_cloud/iteration_30000/point_cloud.ply'\n",
    "\n",
    "# Download and unzip the doll splat, captured and trained by the Kaolin team (please cite Kaolin if you use this model)\n",
    "!if test -d output/doll; then echo \"Pretrained doll splats already exist.\"; else wget https://nvidia-kaolin.s3.us-east-2.amazonaws.com/data/doll.zip -P output/; unzip output/doll.zip -d output/; fi;\n",
    "model_path2 = 'output/doll/point_cloud/iteration_30000/point_cloud.ply' "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74dba058-4a2d-47f7-ada4-671a26a3cad5",
   "metadata": {},
   "source": [
    "### Load 3D Gaussian Splat Models\n",
    "\n",
    "After the setup, we can load and use Kaolin to display the splat model within the Jupyter notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad5a0418-ba7f-4b84-9833-72093f120cf5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "class PipelineParamsNoparse:\n",
    "    \"\"\" Same as PipelineParams but without argument parser. \"\"\"\n",
    "    def __init__(self):\n",
    "        self.convert_SHs_python = False\n",
    "        self.compute_cov3D_python = False #True # covariances will be updated during simulation\n",
    "        self.debug = False\n",
    "\n",
    "def load_model(model_path, sh_degree=3, iteration=-1):\n",
    "    # Load guassians\n",
    "    gaussians = GaussianModel(sh_degree)\n",
    "    gaussians.load_ply(model_path)\n",
    "    logger.info(f'Loaded {gaussians.get_xyz.shape[0]} gaussians from {model_path}')\n",
    "    return gaussians\n",
    "\n",
    "gaussians = load_model(model_path)\n",
    "gaussians2 = load_model(model_path2)\n",
    "pipeline = PipelineParamsNoparse()\n",
    "background = torch.tensor([1, 1, 1], dtype=torch.float32, device=\"cuda\") # Set white bg"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14852d04-c7b5-4d64-ae16-91f0c6e642bc",
   "metadata": {},
   "source": [
    "## Interactive Rendering Using Kaolin Visualizer\n",
    "\n",
    "In order to easily view splats in the notebook, let's set up Gaussian Splat rendering using Kaolin camera conventions.\n",
    "You should be able to see the rendering below this cell and to control the camera with your left mouse button."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1005f2fc-f59d-430c-a4c5-70d7894edca8",
   "metadata": {},
   "outputs": [],
   "source": [
    "resolution = 512\n",
    "default_cam = kaolin.render.camera.Camera.from_args(\n",
    "        eye=torch.ones((3,)) * 2, at=torch.zeros((3,)), up=torch.tensor([0., 0., 1.]),\n",
    "        fov=torch.pi * 45 / 180, height=resolution, width=resolution)\n",
    "\n",
    "class GaussianRenderer:\n",
    "    \"\"\" Define a rendering closure. \"\"\"\n",
    "    def __init__(self, gaussians, downscale_factor=1):\n",
    "        self.gaussians = gaussians\n",
    "        self.downscale_factor = int(downscale_factor)\n",
    "\n",
    "    def downscale_camera(self, in_cam):\n",
    "        lowres_cam = copy.deepcopy(in_cam)\n",
    "        lowres_cam.width = in_cam.width // self.downscale_factor\n",
    "        lowres_cam.height = in_cam.height // self.downscale_factor\n",
    "        return lowres_cam\n",
    "\n",
    "    def __call__(self, camera):\n",
    "        if self.downscale_factor > 1:\n",
    "            camera = self.downscale_camera(camera)\n",
    "        # Convert kaolin camera to inria gaussian-splatting camera\n",
    "        cam = kaolin.render.camera.kaolin_camera_to_gsplats(camera, GSCamera)\n",
    "        # Render gaussians using the inria rendering utilities\n",
    "        render_res = render(cam, self.gaussians, pipeline, background)\n",
    "        rendering = render_res[\"render\"]\n",
    "        return (torch.clamp(rendering.permute(1, 2, 0), 0, 1) * 255).to(torch.uint8).detach().cpu()\n",
    "\n",
    "static_scene_viz = kaolin.visualize.IpyTurntableVisualizer(\n",
    "    resolution, resolution, copy.deepcopy(default_cam), GaussianRenderer(gaussians), \n",
    "    focus_at=None, world_up_axis=2, max_fps=12, img_quality=75, img_format='JPEG')\n",
    "static_scene_viz.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c0703b2d-4944-4f7c-8197-be14011a74c7",
   "metadata": {},
   "source": [
    "## Creating and Training Simplicits Objects from Points\n",
    "[Simplicits](https://research.nvidia.com/labs/toronto-ai/simplicits/) is a mesh-free, representation-agnostic method for simulating elastic deformations. We can use it to simulate Gaussian Splats at interactive rates within the Jupyter notebook. In order to simulate any point-sampled geometry, such as splats, Simplicits first\n",
    "_trains_ an object specific weight function representing the reduced degrees of freedom for the object. The physics solver then uses this reduced space to solve for deformations during simulation.\n",
    "\n",
    "Next, let's use the Simplicits API within Kaolin to create, train and simulate splat objects.\n",
    "\n",
    "First, let's set some material parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfe06442-79cd-45c1-92c0-78395a847bbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Physics material parameters (use approximated values, or look them up online)\n",
    "# We'll create a few presets that can be used\n",
    "youngs_modulus_presets = {\"softest\": 2000, \"soft\": 21000, \"medium\": 1e6, \"stiff\": 1e7}\n",
    "soft_youngs_modulus = youngs_modulus_presets[\"soft\"]  # we will use this for training\n",
    "poisson_ratio = 0.45\n",
    "rho = 100  # kg/m^3\n",
    "approx_volume = 3  # m^3"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "02e574e8-8d2f-4cf3-a19d-9f3cda7fcc9d",
   "metadata": {},
   "source": [
    "### Sampling Within Splat Volume\n",
    "\n",
    "Because splats tend to occupy the surface of the object, they provide poor sampling of the object's interior. This can affect the quality of the learned reduced space. To sample within the splat volume, we will use Kaolin's utility `kaolin.ops.gaussian.sample_points_in_volume`.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d2bd1c6-727e-4183-bf80-618d1283796d",
   "metadata": {},
   "outputs": [],
   "source": [
    "densified_pos = kaolin.ops.gaussian.sample_points_in_volume(\n",
    "    xyz=gaussians.get_xyz.detach(), \n",
    "    scale=gaussians.get_scaling.detach(),\n",
    "    rotation=gaussians.get_rotation.detach(),\n",
    "    opacity=gaussians.get_opacity.detach(),\n",
    "    clip_samples_to_input_bbox=False\n",
    ")\n",
    "log_tensor(gaussians.get_xyz, 'original_pos', print_stats=True)\n",
    "log_tensor(densified_pos, 'densified_pos', print_stats=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84988936-25a0-4033-9e5f-fa0206c7e590",
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualize_pts_k3d(densified_pos, pos):\n",
    "    plot = k3d.plot()\n",
    "    plot += k3d.points(densified_pos.detach().cpu().numpy(), point_size=0.01, color=0x00ff00)\n",
    "    plot += k3d.points(pos.detach().cpu().numpy(), point_size=0.02, color=0xff0000)\n",
    "    plot.display()\n",
    "\n",
    "visualize_pts_k3d(densified_pos, gaussians.get_xyz)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a807d1eb-1be5-4464-82c8-0146a804fd36",
   "metadata": {},
   "source": [
    "### Training\n",
    "Next we create a `SimplicitsObject` and train its skinning weight functions using the volume samples, visualized above. The simulator will then use these reduced degrees of freedom to drive the simulation.\n",
    "\n",
    "**Note:** since training takes a bit of time, we cache the result and reuse it next time we run the notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34b50fe3-13d3-42a5-9326-4b5e8d7f453a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Whether to save reduced degress of freedom used by the simulator and load from cache automatically\n",
    "ENABLE_SIMPLICITS_CACHING = True # set to False to always retrain\n",
    "\n",
    "cache_dir = os.path.join(PHYS_NOTEBOOKS_DIR, 'cache')\n",
    "os.makedirs(cache_dir, exist_ok=True)\n",
    "logger.info(f'Caching trained simplicits objects in {cache_dir}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f98f86e8-5c78-4f0f-9547-d11e46a15721",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def train_or_load_simplicits_object(points, fname):\n",
    "    if not ENABLE_SIMPLICITS_CACHING or not os.path.exists(fname):\n",
    "        logger.info('Training simplicits object. This will take 2-3min... ')\n",
    "        start = time.time()\n",
    "\n",
    "        # One-liner to set up Simplicits object\n",
    "        sim_obj = kaolin.physics.simplicits.SimplicitsObject.create_trained(\n",
    "            points,  # point samples\n",
    "            soft_youngs_modulus, poisson_ratio, rho, approx_volume,  # default global values set above\n",
    "            num_samples=2048, model_layers=10, num_handles=40)\n",
    "        \n",
    "        end = time.time()\n",
    "        logger.info(f\"Ended training in {end-start} seconds\")\n",
    "\n",
    "        # We'll cache the result so we can quickly rerun the notebook.\n",
    "        torch.save(sim_obj, fname)\n",
    "        logger.info(f\"Cached training result in {fname}\")\n",
    "    else:\n",
    "        logger.info(f'Loading cached simplicits object from: {fname}')\n",
    "        sim_obj = torch.load(fname, weights_only=False)\n",
    "    return sim_obj\n",
    "\n",
    "# We'll run training on the first object's volume points\n",
    "sim_obj = train_or_load_simplicits_object(\n",
    "    densified_pos, os.path.join(cache_dir, 'simplicits_dozer.pt'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f4723ae-47f9-40cd-b421-e8355367f886",
   "metadata": {},
   "source": [
    "## Setup Simulated Scene Using Simplicits Easy API\n",
    "Lets create an empty scene with default parameters, then reset the max number of newton steps for faster runtimes.\n",
    "\n",
    "**Note:** be patient, some of the steps below take time, as we need to build matrices used during simulation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "570eba83-c02f-470d-b56a-c696798738f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "scene = kaolin.physics.simplicits.SimplicitsScene() # Create a default scene # default empty scene\n",
    "\n",
    "scene.max_newton_steps = 3 #Convergence might not be guaranteed at few NM iterations, but runs very fast\n",
    "scene.timestep = 0.03\n",
    "scene.newton_hessian_regularizer = 1e-5"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1d6f45f7-f7ff-4b9d-b27a-0a23a4d13f70",
   "metadata": {},
   "source": [
    "Now we add our object to the scene. We use 2048 cubature points to integrate over."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c81b9030-e2de-4686-800c-0e16ab9240c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The scene copies it into an internal SimulatableObject utility class\n",
    "obj_idx = scene.add_object(sim_obj, num_qp=2048)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "199c2abc-96f5-45de-90af-220cedeb9e53",
   "metadata": {},
   "source": [
    "Lets set set gravity and floor forces on the scene"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07868367-8b1a-4d38-8bfb-8227a1bda5e0",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Add gravity to the scene\n",
    "scene.set_scene_gravity(acc_gravity=torch.tensor([0, 0, 9.8]))\n",
    "# Add floor to the scene\n",
    "scene.set_scene_floor(floor_height=-0.7, floor_axis=2, floor_penalty=1000, flip_floor=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d560e57-f5fc-4afc-9999-b259aa20f2f0",
   "metadata": {},
   "source": [
    "We can play around with the material parameters of the object, indicated via object_idx"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4368abe-41ac-44f1-9433-6751df4495e8",
   "metadata": {},
   "source": [
    "## Simulating and Interactive Visualizing \n",
    "\n",
    "That's it! We are ready to simulate. Let's just make sure we can visualize the simulation as it is running."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16e55acb-e698-4e09-939c-2cb37da343e8",
   "metadata": {},
   "source": [
    "### Handling Splat Deformation\n",
    "\n",
    "As the splats deform, we must update their attributes using the transforms predicted by the simulator.\n",
    "For this, we will need the reduced degrees of freedom and ability to apply linear blend skinning to splats. These utilities can be found in `gaussian_utils.py` relative to this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39058a6f-9996-4dbc-b220-98d9bec3242e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# We will save undeformed Gaussian properties, so that we can properly transform and reset them during simulation.\n",
    "rest_xyz = gaussians._xyz.clone()\n",
    "rest_rot = gaussians._rotation.clone()\n",
    "rest_scales = gaussians._scaling.clone()\n",
    "\n",
    "# Precompute learning skinning weights for all splats\n",
    "skinning_weights = sim_obj.skinning_weight_function(rest_xyz)\n",
    "\n",
    "def dozer_to_timestep(transforms):\n",
    "    global gaussians\n",
    "    gaussians._xyz, gaussians._rotation, gaussians._scaling = \\\n",
    "        transform_gaussians_lbs(rest_xyz, rest_rot, rest_scales, skinning_weights, transforms)\n",
    "    \n",
    "# Reset to rest pose\n",
    "def reset_single_object_sim():\n",
    "    global gaussians\n",
    "    scene.reset_scene()\n",
    "    dozer_to_timestep(scene.get_object_transforms(obj_idx))\n",
    "\n",
    "reset_single_object_sim()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32699363-7667-42a7-b71a-0e8cd639cc21",
   "metadata": {},
   "source": [
    "### Threading\n",
    "\n",
    "We will run simulation in a separate thread, so it is possible to interact with the viewer as the simulation is running (in fact, it's encouraged). We'll reuse these utils for this and the multi-object simulation below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57e046e4-a8f6-4f3b-a6f4-761176d30f51",
   "metadata": {},
   "outputs": [],
   "source": [
    "sim_thread_open = False\n",
    "sim_thread = None\n",
    "\n",
    "def wait_for_simulation(visualizer):\n",
    "    global sim_thread_open, sim_thread\n",
    "    with visualizer.out:\n",
    "        if sim_thread_open:\n",
    "            sim_thread.join()\n",
    "            sim_thread_open = False\n",
    "    \n",
    "def start_simulation(sim_function, visualizer):\n",
    "    wait_for_simulation(visualizer)\n",
    "    \n",
    "    global sim_thread_open, sim_thread\n",
    "    with visualizer.out:\n",
    "        sim_thread_open = True\n",
    "        sim_thread = threading.Thread(target=sim_function, daemon=True)\n",
    "        sim_thread.start()\n",
    "\n",
    "def reset_simulation(reset_function, visualizer):\n",
    "    with visualizer.out:\n",
    "        reset_function()\n",
    "    visualizer.render_update()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d1b0f181-4917-4735-b2eb-196880e3a602",
   "metadata": {},
   "source": [
    "### Simulation: Let's bring everything together!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08e353fd-1d27-4ddb-ad11-4c77fb0f8c11",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_sim_iterations = 100\n",
    "reset_single_object_sim()\n",
    "\n",
    "def single_object_sim():\n",
    "    for s in range(num_sim_iterations):\n",
    "        with sim_visualizer.out:\n",
    "            scene.run_sim_step()\n",
    "            print(\".\", end=\"\")\n",
    "            with torch.no_grad():\n",
    "                dozer_to_timestep(scene.get_object_transforms(obj_idx))\n",
    "        sim_visualizer.render_update()\n",
    "\n",
    "resolution = 512\n",
    "sim_visualizer = kaolin.visualize.IpyTurntableVisualizer(\n",
    "    resolution, resolution, copy.deepcopy(default_cam),\n",
    "    GaussianRenderer(gaussians), fast_render=GaussianRenderer(gaussians, 8),\n",
    "    focus_at=torch.tensor([0, 0, 0.0]),\n",
    "    world_up_axis=2, max_fps=12, img_quality=75, img_format='JPEG')\n",
    "\n",
    "buttons = [Button(description=x) for x in\n",
    "           ['Run Sim', 'Reset']]\n",
    "buttons[0].on_click(lambda e: start_simulation(single_object_sim, sim_visualizer))\n",
    "buttons[1].on_click(lambda e: reset_simulation(reset_single_object_sim, sim_visualizer))\n",
    "\n",
    "sim_visualizer.render_update()\n",
    "display(VBox([HBox([sim_visualizer.canvas, VBox(buttons)]), sim_visualizer.out]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37a360af-788e-459e-837e-9f6ca193b511",
   "metadata": {},
   "source": [
    "# Part 2: Multiple Objects and Collisions\n",
    "\n",
    "It's time to make this simulation more exciting. Let's train and add the second object that we loaded above to the simulation."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ecf0011-d27b-4431-a1fb-40924681eabe",
   "metadata": {},
   "source": [
    "### Train Second Simplicits Object\n",
    "\n",
    "As before, we will sample and visualizer points in the object volume. Then, we'll train and cache a simplicits object."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca85da53-4ecf-44ad-bbdb-ec3b3b8fe323",
   "metadata": {},
   "outputs": [],
   "source": [
    "densified_pos2 = kaolin.ops.gaussian.sample_points_in_volume(\n",
    "    xyz=gaussians2.get_xyz.detach(), \n",
    "    scale=gaussians2.get_scaling.detach(),\n",
    "    rotation=gaussians2.get_rotation.detach(),\n",
    "    opacity=gaussians2.get_opacity.detach(),\n",
    "    clip_samples_to_input_bbox=False\n",
    ")\n",
    "log_tensor(gaussians2.get_xyz, 'original_pos', print_stats=True)\n",
    "log_tensor(densified_pos2, 'densified_pos', print_stats=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bae973ca-64fa-4ebf-850f-badd3fa3c380",
   "metadata": {},
   "outputs": [],
   "source": [
    "visualize_pts_k3d(densified_pos2, gaussians2.get_xyz)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2579b5b-c7ba-4f25-bf0d-8363fcfad9e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# We'll run training on the second object's volume points\n",
    "sim_obj2 = train_or_load_simplicits_object(\n",
    "    densified_pos2, os.path.join(cache_dir, 'simplicits_doll.pt'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e55f2932-afd7-49a6-91fc-f23772bf32a6",
   "metadata": {},
   "source": [
    "### Set Up New Scene\n",
    "\n",
    "We'll set up a new scene to make sure the previous simulation cell is still functional."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76a60794-c0ef-40e8-bd0e-814d193fbf4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "scene2 = kaolin.physics.simplicits.SimplicitsScene() # Create a default scene # default empty scene\n",
    "\n",
    "scene2.max_newton_steps = 3 #Convergence might not be guaranteed at few NM iterations, but runs very fast\n",
    "scene2.timestep = 0.03\n",
    "scene2.newton_hessian_regularizer = 1e-5"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b307cd08-334d-4dd7-a61f-2344430aa4c7",
   "metadata": {},
   "source": [
    "We'll add 2 objects this time, offsetting the doll in the z direction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9675c1d-42a5-4267-a439-6aaf14769da9",
   "metadata": {},
   "outputs": [],
   "source": [
    "scene2_obj_idx = scene2.add_object(sim_obj, \n",
    "                                   num_qp=2048)\n",
    "\n",
    "scene2_obj_idx2 = scene2.add_object(sim_obj2, \n",
    "                                    num_qp=2048,\n",
    "                                   init_transform=torch.tensor([[1,0,0,0],\n",
    "                                                                [0,1,0,0],\n",
    "                                                                [0,0,1,1], \n",
    "                                                                [0,0,0,1]], dtype=torch.float32, \n",
    "                                                               device=gaussians.get_xyz.device)) "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f5d97a00-319a-4b6d-8edf-8ca45dee6e63",
   "metadata": {},
   "source": [
    "We'll set up forces as before."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddcb1608-ffc5-4f4e-b073-8729fe5f1ad6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add gravity to the scene\n",
    "scene2.set_scene_gravity(acc_gravity=torch.tensor([0, 0, 9.8]))\n",
    "# Add floor to the scene\n",
    "scene2.set_scene_floor(floor_height=-0.7, floor_axis=2, floor_penalty=1000, flip_floor=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "691e27be-a266-408d-8e04-5e15dadbcfa4",
   "metadata": {},
   "source": [
    "### Enable Collisions (new!)\n",
    "\n",
    "We will enable inter-object collisions here. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a47be9b-8477-487f-9065-0ab18743fe9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "scene2.enable_collisions(collision_particle_radius=0.1, # radius of each collision particle - energy starts accumulating at r\n",
    "                        detection_ratio=1.5, # radius * detection ratio is the area that is searched for potential contact\n",
    "                        impenetrable_barrier_ratio=0.25, # radius * barrier is the distance at which energy is infinite\n",
    "                        collision_penalty=1000.0, # coefficient of collision energy, force, gradient\n",
    "                        max_contact_pairs=10000, # the maximum number of particle contact pairs to allow\n",
    "                        friction=0.5, # friction coefficient\n",
    "                    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d34b0a72-f19b-4caf-82e9-f6343787d06d",
   "metadata": {},
   "source": [
    "### Handle Deforming and Rendering Multiple Gaussians\n",
    "\n",
    "Because the inria render is not set up to render multi-object scenes, we need to do a little work in order to visualize the simulation. Let's concatenate both objects into a single GaussianModel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c7107bf-1da1-4d51-b446-07400015c591",
   "metadata": {},
   "outputs": [],
   "source": [
    "reset_single_object_sim()\n",
    "\n",
    "combined_gaussians = GaussianModel(sh_degree=3)\n",
    "combined_gaussians._xyz = torch.cat([\n",
    "    gaussians._xyz, gaussians2._xyz\n",
    "], dim=0)\n",
    "combined_gaussians._scaling = torch.cat([\n",
    "    gaussians._scaling, gaussians2._scaling\n",
    "], dim=0)\n",
    "combined_gaussians._rotation = torch.cat([\n",
    "    gaussians._rotation, gaussians2._rotation\n",
    "], dim=0)\n",
    "combined_gaussians._opacity = torch.cat([\n",
    "    gaussians._opacity, gaussians2._opacity\n",
    "], dim=0)\n",
    "combined_gaussians._features_dc = torch.cat([\n",
    "    gaussians._features_dc, gaussians2._features_dc\n",
    "], dim=0)\n",
    "combined_gaussians._features_rest = torch.cat([\n",
    "    gaussians._features_rest, gaussians2._features_rest\n",
    "], dim=0)\n",
    "\n",
    "# Save rest state of the combined model\n",
    "combined_rest_xyz = combined_gaussians._xyz.clone()\n",
    "combined_rest_rot = combined_gaussians._rotation.clone()\n",
    "combined_rest_scales = combined_gaussians._scaling.clone()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7f0409f-31a4-407b-a0cb-228353e11536",
   "metadata": {},
   "source": [
    "Let's make sure we can deform both objects using the learned degrees of freedom, which the Simplicits simulator is using to predict deformations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c6f1134-e50d-4240-bf0d-b4a130d0a46f",
   "metadata": {},
   "outputs": [],
   "source": [
    "skinning_weights2 = sim_obj2.skinning_weight_function(gaussians2.get_xyz)\n",
    "_stacked_skinning_weights = [skinning_weights, skinning_weights2]\n",
    "combined_skinning_weights = torch.cat([torch.block_diag(*_stacked_skinning_weights)])\n",
    "\n",
    "def combined_to_timestep(warp_z):\n",
    "    global combined_gaussians\n",
    "    # TODO: switch to using scene.get_object_transforms\n",
    "    obj_tfms = wp.to_torch(warp_z, requires_grad=False).reshape((-1, 3, 4))\n",
    "    transforms = pad_transforms(obj_tfms).unsqueeze(0)\n",
    "    combined_gaussians._xyz, combined_gaussians._rotation, combined_gaussians._scaling = \\\n",
    "        transform_gaussians_lbs(combined_rest_xyz, combined_rest_rot, combined_rest_scales, combined_skinning_weights, transforms)\n",
    "\n",
    "def reset_multi_object_sim():\n",
    "    global combined_gaussians\n",
    "    scene2.reset_scene()\n",
    "    combined_to_timestep(scene2.sim_z)\n",
    "\n",
    "reset_multi_object_sim()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19c5dc40-04ae-4e04-95eb-36bdf7dbae14",
   "metadata": {},
   "source": [
    "### Simulate and Visualize\n",
    "\n",
    "Now we are ready to run the simulation and visualize it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7eab0e70-04a2-459a-82f6-11bc48b07449",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_sim_iterations = 100\n",
    "reset_multi_object_sim()\n",
    "\n",
    "def run_one_multisim_step():\n",
    "    scene2.run_sim_step()\n",
    "    with torch.no_grad():\n",
    "        combined_to_timestep(scene2.sim_z)\n",
    "    multi_sim_visualizer.render_update()\n",
    "\n",
    "def multi_object_sim():\n",
    "    for s in range(num_sim_iterations):\n",
    "        with multi_sim_visualizer.out:\n",
    "            print(\".\", end=\"\")\n",
    "            run_one_multisim_step()\n",
    "\n",
    "multi_cam = kaolin.render.camera.Camera.from_args(\n",
    "        eye=torch.tensor([3.0, 2.0, 3.0]), at=torch.zeros((3,)), up=torch.tensor([0., 0., 1.]),\n",
    "        fov=torch.pi * 45 / 180, height=resolution, width=resolution)\n",
    "\n",
    "resolution = 700\n",
    "multi_sim_visualizer = kaolin.visualize.IpyTurntableVisualizer(\n",
    "    resolution, resolution, multi_cam,\n",
    "    GaussianRenderer(combined_gaussians), fast_render=GaussianRenderer(combined_gaussians, 8),\n",
    "    focus_at=torch.tensor([0, 0, -0.7]),\n",
    "    world_up_axis=2, max_fps=12, img_quality=75, img_format='JPEG')\n",
    "\n",
    "buttons = [Button(description=x) for x in\n",
    "           ['Run Sim', 'Reset']]\n",
    "buttons[0].on_click(lambda e: start_simulation(multi_object_sim, multi_sim_visualizer))\n",
    "buttons[1].on_click(lambda e: reset_simulation(reset_multi_object_sim, multi_sim_visualizer))\n",
    "\n",
    "run_one_multisim_step()\n",
    "multi_sim_visualizer.render_update()\n",
    "display(VBox([HBox([multi_sim_visualizer.canvas, VBox(buttons)]), multi_sim_visualizer.out]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
